{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "from tensorflow.examples.tutorials.mnist import input_data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 1.设置输入和输出节点的个数,配置神经网络的参数。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "INPUT_NODE = 784     # 输入节点\n",
    "OUTPUT_NODE = 10     # 输出节点\n",
    "LAYER1_NODE = 500    # 隐藏层数       \n",
    "                              \n",
    "BATCH_SIZE = 100     # 每次batch打包的样本个数        \n",
    "\n",
    "# 模型相关的参数\n",
    "LEARNING_RATE_BASE = 0.8      \n",
    "LEARNING_RATE_DECAY = 0.99    \n",
    "REGULARAZTION_RATE = 0.0001   \n",
    "TRAINING_STEPS = 5000        \n",
    "MOVING_AVERAGE_DECAY = 0.99  "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 2. 定义辅助函数来计算前向传播结果，使用ReLU做为激活函数。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def inference(input_tensor, avg_class, weights1, biases1, weights2, biases2):\n",
    "    # 不使用滑动平均类\n",
    "    if avg_class == None:\n",
    "        layer1 = tf.nn.relu(tf.matmul(input_tensor, weights1) + biases1)\n",
    "        return tf.matmul(layer1, weights2) + biases2\n",
    "\n",
    "    else:\n",
    "        # 使用滑动平均类\n",
    "        layer1 = tf.nn.relu(tf.matmul(input_tensor, avg_class.average(weights1)) + avg_class.average(biases1))\n",
    "        return tf.matmul(layer1, avg_class.average(weights2)) + avg_class.average(biases2)  "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 3. 定义训练过程。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def train(mnist):\n",
    "    x = tf.placeholder(tf.float32, [None, INPUT_NODE], name='x-input')\n",
    "    y_ = tf.placeholder(tf.float32, [None, OUTPUT_NODE], name='y-input')\n",
    "    # 生成隐藏层的参数。\n",
    "    weights1 = tf.Variable(tf.truncated_normal([INPUT_NODE, LAYER1_NODE], stddev=0.1))\n",
    "    biases1 = tf.Variable(tf.constant(0.1, shape=[LAYER1_NODE]))\n",
    "    # 生成输出层的参数。\n",
    "    weights2 = tf.Variable(tf.truncated_normal([LAYER1_NODE, OUTPUT_NODE], stddev=0.1))\n",
    "    biases2 = tf.Variable(tf.constant(0.1, shape=[OUTPUT_NODE]))\n",
    "\n",
    "    # 计算不含滑动平均类的前向传播结果\n",
    "    y = inference(x, None, weights1, biases1, weights2, biases2)\n",
    "    \n",
    "    # 定义训练轮数及相关的滑动平均类 \n",
    "    global_step = tf.Variable(0, trainable=False)\n",
    "    variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)\n",
    "    variables_averages_op = variable_averages.apply(tf.trainable_variables())\n",
    "    average_y = inference(x, variable_averages, weights1, biases1, weights2, biases2)\n",
    "    \n",
    "    # 计算交叉熵及其平均值\n",
    "    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.argmax(y_, 1))\n",
    "    cross_entropy_mean = tf.reduce_mean(cross_entropy)\n",
    "    \n",
    "    # 损失函数的计算\n",
    "    regularizer = tf.contrib.layers.l2_regularizer(REGULARAZTION_RATE)\n",
    "    regularaztion = regularizer(weights1) + regularizer(weights2)\n",
    "    loss = cross_entropy_mean + regularaztion\n",
    "    \n",
    "    # 设置指数衰减的学习率。\n",
    "    learning_rate = tf.train.exponential_decay(\n",
    "        LEARNING_RATE_BASE,\n",
    "        global_step,\n",
    "        mnist.train.num_examples / BATCH_SIZE,\n",
    "        LEARNING_RATE_DECAY,\n",
    "        staircase=True)\n",
    "    \n",
    "    # 优化损失函数\n",
    "    train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=global_step)\n",
    "    \n",
    "    # 反向传播更新参数和更新每一个参数的滑动平均值\n",
    "    # tf支持进行一次完成多个操作，既需要进行train_step又需要variables_averages_op\n",
    "    # 例如创建一个group，把train_step和variables_averages_op两个操作放在一起进行，等同于以下操作：\n",
    "    # with tf.control_dependencies([train_step, variables_averages_op]):\n",
    "    #     train_op = tf.no_op(name='train')\n",
    "    train_op = tf.group(train_step, variables_averages_op)    \n",
    "\n",
    "    # 计算正确率\n",
    "    # average_y.shape = [None, OUTPUT_NODE]，tf.argmax(average_y, 1)表示返回average_y中最大值的序号\n",
    "    # Signature: tf.argmax(input, axis=None, name=None, dimension=None, output_type=tf.int64)\n",
    "    # Returns the index with the largest value across axes of a tensor. (deprecated arguments)\n",
    "\n",
    "    correct_prediction = tf.equal(tf.argmax(average_y, 1), tf.argmax(y_, 1))\n",
    "    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))\n",
    "    \n",
    "    # 初始化会话，并开始训练过程。\n",
    "    with tf.Session() as sess:\n",
    "        tf.global_variables_initializer().run()\n",
    "        validate_feed = {x: mnist.validation.images, y_: mnist.validation.labels}\n",
    "        test_feed = {x: mnist.test.images, y_: mnist.test.labels} \n",
    "        \n",
    "        # 循环的训练神经网络。\n",
    "        for i in range(TRAINING_STEPS):\n",
    "            if i % 1000 == 0:\n",
    "                validate_acc = sess.run(accuracy, feed_dict=validate_feed)\n",
    "                print(\"After %d training step(s), validation accuracy using average model is %g \" % (i, validate_acc))\n",
    "            \n",
    "            xs,ys=mnist.train.next_batch(BATCH_SIZE)\n",
    "            sess.run(train_op, feed_dict={x:xs,y_:ys})\n",
    "\n",
    "        test_acc=sess.run(accuracy,feed_dict=test_feed)\n",
    "        print((\"After %d training step(s), test accuracy using average model is %g\" %(TRAINING_STEPS, test_acc)))\n",
    "   "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 4. 主程序入口，这里设定模型训练次数为5000次。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting ../../datasets/MNIST_data/train-images-idx3-ubyte.gz\n",
      "Extracting ../../datasets/MNIST_data/train-labels-idx1-ubyte.gz\n",
      "Extracting ../../datasets/MNIST_data/t10k-images-idx3-ubyte.gz\n",
      "Extracting ../../datasets/MNIST_data/t10k-labels-idx1-ubyte.gz\n",
      "After 0 training step(s), validation accuracy using average model is 0.113 \n",
      "After 1000 training step(s), validation accuracy using average model is 0.9776 \n",
      "After 2000 training step(s), validation accuracy using average model is 0.9814 \n",
      "After 3000 training step(s), validation accuracy using average model is 0.9834 \n",
      "After 4000 training step(s), validation accuracy using average model is 0.9836 \n",
      "After 5000 training step(s), test accuracy using average model is 0.9832\n"
     ]
    }
   ],
   "source": [
    "def main(argv=None):\n",
    "    mnist = input_data.read_data_sets(\"../../datasets/MNIST_data\", one_hot=True)\n",
    "    train(mnist)\n",
    "\n",
    "if __name__=='__main__':\n",
    "    main()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.2"
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
